#%%
import numpy as np

# Output directory of the training run being inspected; edit only this line
# to look at a different run.
RUN_DIR = 'log/20201111015958_2acfeCora_Cora_T'
# Filename prefix shared by both saved arrays.
# NOTE(review): the '.png' in the middle of the name looks like an artifact of
# how the producer built these filenames (image path + suffix). The files on
# disk are named this way, so it is kept verbatim; confirm before renaming.
ARTIFACT_PREFIX = RUN_DIR + '/train_epoch_em400.png'

# Embedding matrix and labels. The plotting cell feeds `emb` to UMAP, which
# treats rows as samples, and uses `lab` as one colour value per row.
emb = np.load(ARTIFACT_PREFIX + 'emb.npy')
lab = np.load(ARTIFACT_PREFIX + 'lab.npy')

# Fail fast: a row-count mismatch would otherwise only surface in
# plt.scatter, after the (slow) UMAP fit has already run.
if len(emb) != len(lab):
    raise ValueError(
        'embedding/label length mismatch: {} rows vs {} labels'.format(
            len(emb), len(lab)))

emb.shape

# %%
import matplotlib.pyplot as plt
import umap


# UMAP hyper-parameter grid; one scatter plot is saved per combination.
N_NEIGHBORS_GRID = [100]
MIN_DIST_GRID = [0.001]

for n_neighbors in N_NEIGHBORS_GRID:
    for min_dist in MIN_DIST_GRID:
        print(n_neighbors, min_dist)

        # Project to 2-D before creating a figure, so a failing fit does not
        # leave an empty figure open in pyplot's global state.
        reducer = umap.UMAP(
            n_components=2,
            n_neighbors=n_neighbors,
            min_dist=min_dist,
        )
        data_em = reducer.fit_transform(emb)

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.scatter(
                data_em[:, 0],
                data_em[:, 1],
                c=lab,  # colour each projected point by its label
                s=4,
                cmap='rainbow',
            )
            # Output name encodes the hyper-parameters, e.g. "100_0.001.png".
            fig.savefig('{}_{}.png'.format(n_neighbors, min_dist))
        finally:
            # Always release the figure. pyplot keeps figures alive until they
            # are closed, so otherwise they would accumulate over the sweep.
            plt.close(fig)
# %%
# Show the dimensions of the loaded embedding array as this cell's output.
np.shape(emb)
# %%
